import * as tf from '@tensorflow/tfjs-core'
import * as box_1 from './box.js'

// Anchor layout of the BlazeFace output head: one anchor grid per stride,
// with `anchors[i]` anchors placed at every cell of the stride-`strides[i]` grid.
// Frozen because it is shared, read-only configuration (generateAnchors only reads it).
const ANCHORS_CONFIG = Object.freeze({
  'strides': Object.freeze([8, 16]),
  'anchors': Object.freeze([2, 6])
})

// Number of 2-D facial landmarks predicted per detection.
const NUM_LANDMARKS = 6

/**
 * Builds the list of anchor center points for every output grid of the model.
 *
 * For each stride the input is divided into a ceil(width/stride) x
 * ceil(height/stride) grid, and `outputSpec.anchors[i]` identical anchor
 * centers are emitted per cell, at the cell's center in input-pixel units.
 *
 * @param {number} width  Model input width in pixels.
 * @param {number} height Model input height in pixels.
 * @param {{strides: number[], anchors: number[]}} outputSpec Grid spec
 *   (parallel arrays: stride per output head, anchors per cell).
 * @returns {number[][]} Array of `[x, y]` anchor centers, ordered by
 *   stride, then row-major over the grid, then anchor index.
 */
function generateAnchors(width, height, outputSpec) {
  const anchors = []
  for (let i = 0; i < outputSpec.strides.length; i++) {
    const stride = outputSpec.strides[i]
    // Math.ceil(x / stride) is the idiomatic form of the original
    // Math.floor((x + stride - 1) / stride) for positive operands.
    const gridRows = Math.ceil(height / stride)
    const gridCols = Math.ceil(width / stride)
    const anchorsNum = outputSpec.anchors[i]
    for (let gridY = 0; gridY < gridRows; gridY++) {
      // Anchor sits at the center of the grid cell, in input pixels.
      const anchorY = stride * (gridY + 0.5)
      for (let gridX = 0; gridX < gridCols; gridX++) {
        const anchorX = stride * (gridX + 0.5)
        for (let n = 0; n < anchorsNum; n++) {
          anchors.push([anchorX, anchorY])
        }
      }
    }
  }
  return anchors
}

/**
 * Converts raw anchor-relative box predictions into absolute
 * [xStart, yStart, xEnd, yEnd] corner coordinates (input-pixel units).
 *
 * @param {tf.Tensor2D} boxOutputs Raw model rows; columns 1-2 are center
 *   offsets, columns 3-4 are box extents.
 * @param {tf.Tensor2D} anchors    [x, y] anchor centers, one per row.
 * @param {tf.Tensor1D} inputSize  [width, height] of the model input.
 * @returns {tf.Tensor2D} One [start, end] corner pair per prediction row.
 */
function decodeBounds(boxOutputs, anchors, inputSize) {
  // Columns 1-2 hold the predicted center offset from each row's anchor.
  const rawCenters = tf.slice(boxOutputs, [0, 1], [-1, 2])
  const absoluteCenters = tf.add(rawCenters, anchors)
  // Columns 3-4 hold the predicted box width/height.
  const extents = tf.slice(boxOutputs, [0, 3], [-1, 2])
  // Work in normalized [0, 1] coordinates, then scale back to pixels.
  const extentsScaled = tf.div(extents, inputSize)
  const centersScaled = tf.div(absoluteCenters, inputSize)
  const halfExtents = tf.div(extentsScaled, 2)
  const topLeft = tf.sub(centersScaled, halfExtents)
  const bottomRight = tf.add(centersScaled, halfExtents)
  const topLeftPx = tf.mul(topLeft, inputSize)
  const bottomRightPx = tf.mul(bottomRight, inputSize)
  // Concatenate along columns: each row becomes [x1, y1, x2, y2].
  return tf.concat2d([topLeftPx, bottomRightPx], 1)
}

/**
 * Wrapper around a loaded BlazeFace graph model. Runs inference, decodes
 * anchor-relative boxes, applies non-max suppression, and (optionally)
 * rescales results back to the original image resolution.
 */
export class BlazeFaceModel {
  /**
   * @param model          Loaded model whose `predict` returns a batched
   *                       prediction tensor (one row per anchor).
   * @param width          Model input width in pixels.
   * @param height         Model input height in pixels.
   * @param maxFaces       Maximum number of boxes kept by NMS.
   * @param iouThreshold   IoU threshold passed to nonMaxSuppressionAsync.
   * @param scoreThreshold Minimum score passed to nonMaxSuppressionAsync.
   */
  constructor(model, width, height, maxFaces, iouThreshold, scoreThreshold) {
    this.blazeFaceModel = model
    this.width = width
    this.height = height
    this.maxFaces = maxFaces
    // Anchors are kept both as a plain array (for the non-tensor path)
    // and as a tensor (for the tensor path), to avoid repeated conversion.
    this.anchorsData = generateAnchors(width, height, ANCHORS_CONFIG)
    this.anchors = tf.tensor2d(this.anchorsData)
    this.inputSizeData = [width, height]
    this.inputSize = tf.tensor1d([width, height])
    this.iouThreshold = iouThreshold
    this.scoreThreshold = scoreThreshold
  }
  /**
   * Runs the model on a batched image tensor and returns NMS-filtered
   * detections plus the factor needed to scale boxes back to the input's
   * original resolution.
   *
   * @param inputImage    4-D image tensor; shape[1]/shape[2] are read as
   *                      height/width (NHWC layout).
   * @param returnTensors When true, results stay as tensors the caller must
   *                      dispose; when false they are downloaded to arrays.
   * @returns [annotatedBoxes, scaleFactor] where each annotated box carries
   *          { box, landmarks, probability, anchor }.
   */
  async getBoundingBoxes(inputImage, returnTensors) {
    // Tensors returned from tidy survive it; everything else is reclaimed.
    const [detectedOutputs, boxes, scores] = tf.tidy(() => {
      // NOTE(review): resizeBilinear takes [newHeight, newWidth]; passing
      // [width, height] only works because width === height here — confirm.
      const resizedImage = inputImage.resizeBilinear([this.width, this.height])
      // Map pixel values from [0, 255] to [-1, 1].
      const normalizedImage = tf.mul(tf.sub(resizedImage.div(255), 0.5), 2)
      const batchedPrediction = this.blazeFaceModel.predict(normalizedImage)
      // Drop the batch dimension: one row per anchor.
      const prediction = batchedPrediction.squeeze()
      const decodedBounds = decodeBounds(prediction, this.anchors, this.inputSize)
      // Column 0 is the face-presence logit; sigmoid turns it into a score.
      const logits = tf.slice(prediction, [0, 0], [-1, 1])
      return [prediction, decodedBounds, tf.sigmoid(logits).squeeze()]
    })
    const boxIndicesTensor = await tf.image.nonMaxSuppressionAsync(boxes, scores, this.maxFaces, this.iouThreshold, this.scoreThreshold)
    const boxIndices = await boxIndicesTensor.array()
    boxIndicesTensor.dispose()
    // One [1, 4] tensor per surviving detection.
    let boundingBoxes = boxIndices.map((boxIndex) => tf.slice(boxes, [boxIndex, 0], [1, -1]))
    if (!returnTensors) {
      // Download box values and release the per-box tensors.
      boundingBoxes = await Promise.all(boundingBoxes.map(async(boundingBox) => {
        const vals = await boundingBox.array()
        boundingBox.dispose()
        return vals
      }))
    }
    // NHWC: height is dim 1, width is dim 2 of the batched input.
    const originalHeight = inputImage.shape[1]
    const originalWidth = inputImage.shape[2]
    let scaleFactor
    if (returnTensors) {
      // NOTE(review): this tensor is handed to the caller and never disposed
      // here; estimateFaces' tensor path does not dispose it either — verify
      // for leaks under repeated calls.
      scaleFactor = tf.div([originalWidth, originalHeight], this.inputSize)
    } else {
      scaleFactor = [
        originalWidth / this.inputSizeData[0],
        originalHeight / this.inputSizeData[1]
      ]
    }
    const annotatedBoxes = boundingBoxes
      .map((boundingBox, i) => tf.tidy(() => {
        const boxIndex = boxIndices[i]
        let anchor
        if (returnTensors) {
          anchor = this.anchors.slice([boxIndex, 0], [1, 2])
        } else {
          // Plain [x, y] pair from the precomputed anchor list.
          anchor = this.anchorsData[boxIndex]
        }
        const box = boundingBox instanceof tf.Tensor
          ? box_1.createBox(boundingBox)
          : box_1.createBox(tf.tensor2d(boundingBox))
        // Landmark coords start at column NUM_LANDMARKS - 1 (= 5) of the raw
        // row; presumably that skips the logit plus 4 box values — confirm
        // against the model's output layout.
        const landmarks = tf.slice(detectedOutputs, [boxIndex, NUM_LANDMARKS - 1], [1, -1])
          .squeeze()
          .reshape([NUM_LANDMARKS, -1])
        const probability = tf.slice(scores, [boxIndex], [1])
        return { box, landmarks, probability, anchor }
      }))
    // The per-box tensors created above escape their tidy; the shared
    // intermediates are released here once all boxes are annotated.
    boxes.dispose()
    scores.dispose()
    detectedOutputs.dispose()
    return [annotatedBoxes, scaleFactor]
  }
  /**
   * Detects faces in an image and returns, per face, the top-left/bottom-right
   * corners, landmark coordinates, and probability — all scaled to the
   * original image resolution.
   *
   * @param input         Image tensor or a browser pixel source accepted by
   *                      tf.browser.fromPixels.
   * @param returnTensors When true, returned fields are tensors the caller
   *                      must dispose; otherwise plain arrays/numbers.
   */
  async estimateFaces(input, returnTensors = false) {
    const image = tf.tidy(() => {
      if (!(input instanceof tf.Tensor)) {
        input = tf.browser.fromPixels(input)
      }
      // Add a batch dimension; model expects float input.
      return input.toFloat().expandDims(0)
    })
    const [prediction, scaleFactor] = await this.getBoundingBoxes(image, returnTensors)
    image.dispose()
    if (returnTensors) {
      // Tensor path: caller owns every returned tensor.
      return prediction.map((face) => {
        const scaledBox = box_1.scaleBox(face.box, scaleFactor).startEndTensor.squeeze()
        return {
          topLeft: scaledBox.slice([0], [2]),
          bottomRight: scaledBox.slice([2], [2]),
          // Landmarks are anchor-relative; shift by the anchor then rescale.
          landmarks: face.landmarks.add(face.anchor).mul(scaleFactor),
          probability: face.probability
        }
      })
    }
    return Promise.all(prediction.map(async(face) => {
      const scaledBox = tf.tidy(() => {
        return box_1.scaleBox(face.box, scaleFactor).startEndTensor.squeeze()
      })
      // Download all three tensors in parallel before disposing them.
      const [landmarkData, boxData, probabilityData] = await Promise.all([face.landmarks, scaledBox, face.probability].map(async(d) => d.array()))
      const anchor = face.anchor
      // Same transform as the tensor path, done on plain numbers.
      const scaledLandmarks = landmarkData
        .map((landmark) => ([
          (landmark[0] + anchor[0]) * scaleFactor[0],
          (landmark[1] + anchor[1]) * scaleFactor[1]
        ]))
      scaledBox.dispose()
      box_1.disposeBox(face.box)
      face.landmarks.dispose()
      face.probability.dispose()
      return {
        topLeft: boxData.slice(0, 2),
        bottomRight: boxData.slice(2),
        landmarks: scaledLandmarks,
        probability: probabilityData
      }
    }))
  }
}
